# NOTE(review): removes every object that `ls()` lists in the current
# environment before the script runs. It does not detach packages or reset
# options, so it is not a substitute for starting from a fresh R session.
rm(list = ls())

## Packages

library(tidyverse)
library(lfe)

## Get data

# Read the cleaned data, derive the unit id and the two 0/1 treatment
# dummies, then keep the first row found for each `county_id`.
df <- read_rds("data/Data_Clean.rds") %>%
  mutate(
    year = as.numeric(year),
    # `unit` is `ags`, falling back to `vb_key` where `ags` is missing
    unit = coalesce(ags, vb_key),
    # `==` propagates NA from `treat_categ`; `%in%` maps NA to 0
    treat_weakly = as.numeric(treat_categ == 'schwach'),
    treat_heavy = as.numeric(treat_categ %in% c('schwer', 'sehr schwer'))
  ) %>%
  distinct(county_id, .keep_all = TRUE)

## Additional county covars

# Keep the county key plus the ten background covariates; the key is
# renamed to `county_id`, the county identifier column used in `df`.
bg_df <- read_rds("data/additional_county_covars.rds") %>%
  dplyr::select(
    county_id = ags_2017,
    pop_total,
    pop_density,
    gdp_nominal_2016,
    relig_cath_2011,
    unemp_rate_tot,
    dist_to_state_capital,
    cdu_csu_party_17,
    spd_party_17,
    greens_party_17,
    wage_nominal2016
  )

## Merge

# No `by` is given, so dplyr joins on every column name the two tables
# share and reports the chosen keys in a message.
# NOTE(review): presumably only `county_id` is shared -- confirm.
df <- left_join(df, bg_df)

## Select variables

# Covariates checked for balance; all of them are columns of `bg_df`.
bvars <- c(
  "pop_total",
  "pop_density",
  "gdp_nominal_2016",
  "relig_cath_2011",
  "unemp_rate_tot",
  "dist_to_state_capital",
  "cdu_csu_party_17",
  "spd_party_17",
  "greens_party_17",
  "wage_nominal2016"
)

## Balance function 

## Covariate balance: regress each standardized covariate on a treatment
## term and report the treatment coefficient with a 95% interval.
##
## Arguments:
##   treatvar  String pasted into the formula as the first right-hand-side
##             term (e.g. the name of a 0/1 treatment column).
##   cov_list  Character vector of covariate columns in `data.df`; each is
##             rescaled with scale() (mean 0, sd 1) before fitting.
##   data.df   Data frame holding the treatment, the covariates and, when
##             used, the FE and weights columns.
##   FE        Optional string appended to the formula after `treatvar`
##             with "+"; NULL (default) adds nothing.
##   weights   Optional column of `data.df` (normally its name) passed to
##             lm() as regression weights; NULL (default) fits unweighted.
##
## Returns a data.frame with one row per covariate and the columns `cov`,
## `lower`, `upper`, `coef` (the second model coefficient and that value
## -/+ 1.96 standard errors) and `tv` (= treatvar). If that coefficient
## cannot be estimated the row holds NA and a warning is raised. Returns
## NULL when `cov_list` is empty.
get_bal <- function(treatvar, cov_list,
                    data.df, FE = NULL, weights = NULL) {
  if (length(cov_list) == 0) {
    return(NULL)
  }

  # Standardize once: the rescaling does not depend on which covariate is
  # currently the outcome, so it does not belong inside the loop.
  data.df[cov_list] <- lapply(data.df[cov_list], function(x) {
    as.numeric(scale(x))
  })

  # Resolve the weights column up front. `[[` returns NULL for an unknown
  # name, which lm() would silently treat as "no weights", so fail loudly.
  if (!is.null(weights)) {
    if (is.character(weights) && !all(weights %in% names(data.df))) {
      stop("weights column not found in data.df: ", weights, call. = FALSE)
    }
    .bal_weights <- data.df[[weights]]
  }

  rhs <- if (length(FE) == 0) treatvar else paste0(treatvar, "+", FE)

  out_temp <- lapply(cov_list, function(cv) {
    f <- as.formula(paste0(cv, " ~", rhs))
    m <- if (is.null(weights)) {
      lm(f, data = data.df)
    } else {
      lm(f, data = data.df, weights = .bal_weights)
    }

    # coef() keeps aliased (NA) terms while the summary table drops them,
    # so row 2 of the summary is not always the treatment term. Look the
    # term up by name instead of by position.
    term <- names(coef(m))[2]
    ctab <- coef(summary(m))
    if (!is.na(term) && term %in% rownames(ctab)) {
      est <- unname(ctab[term, 1])
      se <- unname(ctab[term, 2])
    } else {
      warning("coefficient on '", treatvar, "' could not be estimated for '",
              cv, "'; returning NA", call. = FALSE)
      est <- NA_real_
      se <- NA_real_
    }

    data.frame(cov = cv, lower = est - 1.96 * se, upper = est + 1.96 * se,
               coef = est, tv = treatvar, stringsAsFactors = FALSE)
  })
  do.call("rbind", out_temp)
}

## Heavy vs. no treatment 

# Drop rows where `treat_weakly` is 1, then compare the heavy-treatment
# dummy against the remaining rows.
b_heavy <- df %>%
  filter(treat_weakly != 1) %>%
  get_bal(treatvar = 'treat_heavy', cov_list = bvars, data.df = .) %>%
  mutate(what = 'Highly affected\nvs. unaffected')

## Weak vs. no treatment 

# Drop rows where `treat_heavy` is 1, then compare the weak-treatment
# dummy against the remaining rows.
b_weakly <- df %>%
  filter(treat_heavy != 1) %>%
  get_bal(treatvar = 'treat_weakly', cov_list = bvars, data.df = .) %>%
  mutate(what = 'Weakly affected\nvs. unaffected')

## Combine and plot 

# Stack the two comparisons; `what` tells them apart.
bal_combined <- bind_rows(b_heavy, b_weakly)

## Dictionary

# Lookup table from variable name (`var_label`) to display label (`var`).
# `order` is the sort key applied to the labels when the dictionary is
# merged; it is NA for entries that are not in `bvars`.
dict_df <- tibble::tribble(
  ~var_label,                 ~var,                                     ~order,
  "sim_today",                "Similarity to standard German",          NA_real_,
  "pop_density",              "Population density, 2017",               2,
  "pop_total",                "Total population, 2017",                 1,
  "relig_cath_2011",          "Share of Catholic population, 2011",     3,
  "gdp_nominal_2016",         "Nominal GDP/capita, 2016",               4,
  "wage_nominal2016",         "Nominal wage, 2016",                     5,
  "unemp_rate_tot",           "Unemployment rate, 2017",                6,
  "cdu_csu_party_13",         "CDU/CSU vote share, 2013",               NA_real_,
  "commuters_capita_in_2017", "Commuters per capita",                   NA_real_,
  "dist_to_state_capital",    "Distance to resp. state capitals, 2017", 7,
  "any_prog_20s",             NA_character_,                            NA_real_,
  "cdu_csu_party_17",         "CDU/CSU vote share, 2017",               8,
  "spd_party_17",             "SPD vote share, 2017",                   9,
  "greens_party_17",          "Green vote share, 2017",                 10
)

## Merge dictionary

# Attach the display label and sort key for each covariate. The join key
# is spelled out so the merge cannot silently pick up extra shared columns
# and does not emit the "Joining with ..." message.
# `var` becomes a factor whose levels follow `order`.
bal_combined <- bal_combined %>%
  left_join(dict_df %>% dplyr::rename(cov = var_label), by = "cov") %>%
  mutate(var = fct_reorder(var, order))

#### Figure A.8 - balance ####

# Coefficient plot: one point and 95% interval per covariate, one panel
# per comparison, with covariates on the vertical axis after the flip.
p1 <- ggplot(bal_combined, aes(x = var, y = coef)) +
  geom_hline(yintercept = 0, linetype = 'dotted') +
  geom_errorbar(aes(ymin = lower, ymax = upper), width = 0) +
  geom_point(shape = 21, fill = 'white') +
  facet_wrap(~what) +
  coord_flip() +
  labs(x = "", y = 'Standardized difference') +
  theme_bw()
p1